#!/usr/bin/env python
# coding: utf-8

# In[1]:


import pandas as pd
import numpy as np
import random
import math
from sklearn.model_selection import train_test_split
import os

import import_ipynb
import SelectIndices as si

# Fix the seed of Python's built-in RNG
random.seed(10012)


# In[2]:


## Simulation scope: datasets, embeddings, sample sizes, bootstrap iterations
# Datasets to process (alternative full list: ['eo', 'news', 'stwts'])
dataset_names = ['eo', 'stwts']
# Embedding identifiers; they are spliced into input and output file names
embed_types = [
    'cvec_pca16',
    'cvec_nmf16',
    'cvec_umap16',
    'cvec_tsne16',
    'bert',
    'roberta',
    'distil',
    'glove6B',
    'universal',
    'lda100',
]
# Numbers of observations to select per run
counts = [50, 100, 250, 500, 750, 1000, 1500, 2000, 2500, 3000]
# Bootstrap iteration labels 1..10 (a shorter list such as [1, 2] is handy for quick runs)
bootstrap_iters = list(range(1, 11))


# In[3]:


## Symmetric set-difference helper
def Diff(li1, li2):
    """Return the elements that appear in exactly one of the two iterables.

    Duplicates are collapsed. Elements found only in ``li1`` come first,
    followed by elements found only in ``li2``; the order inside each group
    is the iteration order of the underlying sets.

    Note that this is a *symmetric* difference: items of ``li2`` that are
    missing from ``li1`` are part of the result as well.
    """
    first, second = set(li1), set(li2)
    only_in_first = first - second
    only_in_second = second - first
    return [*only_in_first, *only_in_second]


# In[4]:


## Load both cleaned datasets and size their test sets (20%, rounded up)
eos_full = pd.read_csv("data/raw/eo_clean_full.csv", index_col=0)
max_obs_eo = eos_full.shape[0]
print(max_obs_eo)
eo_test_set_size = math.ceil(max_obs_eo*0.2)
print(eo_test_set_size)

stwts_full = pd.read_csv("data/raw/stwts_clean_full.csv", index_col=0)
max_obs_stwts = stwts_full.shape[0]
print(max_obs_stwts)
stwts_test_set_size = math.ceil(max_obs_stwts*0.2)
print(stwts_test_set_size)


# In[5]:


## Draw one test-set sample per bootstrap iteration, for each dataset

def _draw_test_sets(n_obs, n_test):
    """Return one list of sampled row positions per bootstrap iteration.

    The RNG is re-seeded with the zero-based iteration position before each
    draw, and each draw takes ``n_test`` distinct values from
    ``range(0, n_obs)``.
    """
    drawn = []
    for seed, _ in enumerate(bootstrap_iters):
        random.seed(seed)
        drawn.append(random.sample(range(n_obs), n_test))
    return drawn


eo_testset_list = _draw_test_sets(max_obs_eo, eo_test_set_size)
stwts_testset_list = _draw_test_sets(max_obs_stwts, stwts_test_set_size)


# In[6]:


## Output the lists of train/test obs

# Each inner list becomes a row; transposing gives one column per bootstrap
# iteration and one row per sampled test index.
eo_df = pd.DataFrame(eo_testset_list).transpose()
# Label the columns from bootstrap_iters instead of a hard-coded 1..10 so the
# cell keeps working when the iteration list is shortened or extended.
eo_df.columns = ["iter" + str(x) for x in bootstrap_iters]
eo_df.to_csv("data/output/eo_testset_list.csv", index=False)


# In[7]:




# Each inner list becomes a row; transposing gives one column per bootstrap
# iteration and one row per sampled test index.
stwts_df = pd.DataFrame(stwts_testset_list).transpose()
# Label the columns from bootstrap_iters instead of a hard-coded 1..10 so the
# cell keeps working when the iteration list is shortened or extended.
stwts_df.columns = ["iter" + str(x) for x in bootstrap_iters]

stwts_df.to_csv("data/output/stwts_testset_list.csv", index=False)


# In[8]:


# Number of test-set draws held for each dataset
for drawn_sets in (eo_testset_list, stwts_testset_list):
    print(len(drawn_sets))


# In[9]:


# Show the first two 'eo' test sets together with their sizes
for position in (0, 1):
    sample = eo_testset_list[position]
    print(sample)
    print(len(sample))


# ## Random pick

# In[10]:


# For every dataset and bootstrap iteration, draw one random sample of each
# size in `counts` from the non-test rows and write the samples to a text
# file, one sample (printed as a Python list) per line.
for tag, n_obs, testsets in (
    ("eo", max_obs_eo, eo_testset_list),
    ("stwts", max_obs_stwts, stwts_testset_list),
):
    for iter_no, test_idx in enumerate(testsets, start=1):
        # One-sided set difference: a symmetric difference would put row 0
        # back into the pool whenever it was drawn as a test row (the range
        # starts at 1), leaking a test observation into the training picks.
        # NOTE(review): the pool starts at 1, so row 0 is never eligible even
        # though test indices are sampled from 0 -- confirm this is intended.
        # The pool does not depend on the sample size, so build it once.
        remaining_obs = list(set(range(1, n_obs)) - set(test_idx))
        indices_list = [random.sample(remaining_obs, c) for c in counts]
        out_path = "data/output/" + 'indices_' + tag + '_random_iter' + str(iter_no) + '.txt'
        with open(out_path, 'w') as filehandle:
            filehandle.writelines("%s\n" % idl for idl in indices_list)


# ## K-means Clustering

# In[11]:


# K-means based selection: for every bootstrap iteration, dataset and
# embedding, run si.kmeans_indices on the non-test rows for each size in
# `counts` and write the results to a text file, one result per line.
for h in range(len(eo_testset_list)):
    for name in dataset_names:
        if name == "eo":
            n_obs, test_idx = max_obs_eo, eo_testset_list[h]
        elif name == "stwts":
            n_obs, test_idx = max_obs_stwts, stwts_testset_list[h]
        else:
            # Fail loudly: continuing would reuse a stale (or undefined) idx.
            raise ValueError("Unknown dataset name: {!r}".format(name))
        # One-sided set difference: a symmetric difference would put row 0
        # back into the pool whenever it was drawn as a test row (the range
        # starts at 1), leaking a test observation into the candidates.
        # NOTE(review): the pool starts at 1, so row 0 is never eligible even
        # though test indices are sampled from 0 -- confirm this is intended.
        # idx does not depend on the embedding, so compute it once per dataset.
        idx = list(set(range(1, n_obs)) - set(test_idx))
        for embed in embed_types:
            data = pd.read_csv("data/output/" + name + '_' + embed + '_full.csv', index_col=0)
            # Keep the non-test rows (positional) and hand them to torch as float32.
            data = si.torch.from_numpy(data.iloc[idx].to_numpy()).float()
            indices_list = [si.kmeans_indices(data, c) for c in counts]
            out_path = "data/output/" + 'indices_' + name + '_' + embed + '_kmeans_iter' + str(h+1) + '.txt'
            with open(out_path, 'w') as filehandle:
                filehandle.writelines("%s\n" % idl for idl in indices_list)
            print('Completed.')


# ## Greedy farthest points based on KL Divergence

# In[12]:


# Greedy farthest-point selection on the precomputed '_kld_' matrices: for
# every bootstrap iteration, dataset and embedding, restrict the matrix to the
# non-test rows/columns, ask si.farthestPointSampler for max(counts) points
# and write the resulting list to a text file.
n_select = max(counts)
for h in range(len(eo_testset_list)):
    for name in dataset_names:
        if name == "eo":
            n_obs, test_idx = max_obs_eo, eo_testset_list[h]
        elif name == "stwts":
            n_obs, test_idx = max_obs_stwts, stwts_testset_list[h]
        else:
            # Fail loudly: continuing would reuse a stale (or undefined) idx.
            raise ValueError("Unknown dataset name: {!r}".format(name))
        # One-sided set difference: a symmetric difference would put row 0
        # back into the pool whenever it was drawn as a test row (the range
        # starts at 1), leaking a test observation into the candidates.
        # NOTE(review): the pool starts at 1, so row 0 is never eligible even
        # though test indices are sampled from 0 -- confirm this is intended.
        # idx does not depend on the embedding, so compute it once per dataset.
        idx = list(set(range(1, n_obs)) - set(test_idx))
        for embed in embed_types:
            kld_matrix = np.load("data/output/" + name + '_kld_' + embed + '.npy')
            # Keep only the rows, then only the columns, of non-test observations.
            kld_matrix = kld_matrix[idx, :]
            kld_matrix = kld_matrix[:, idx]
            indices_list = list(si.farthestPointSampler(kld_matrix, n_select))
            out_path = "data/output/" + 'indices_' + name + '_' + embed + '_kld_iter' + str(h+1) + '.txt'
            with open(out_path, 'w') as filehandle:
                # A single string is written with write(); the file content
                # is the printed list, with no trailing newline.
                filehandle.write("%s" % indices_list)
            print('Completed.')


# ## Greedy Farthest points based on Kolmogorov-Smirnov measure

# In[13]:


# Greedy farthest-point selection on the precomputed '_ks_' matrices: for
# every bootstrap iteration, dataset and embedding, restrict the matrix to the
# non-test rows/columns, ask si.farthestPointSampler for max(counts) points
# and write the resulting list to a text file.
n_select = max(counts)
for h in range(len(eo_testset_list)):
    for name in dataset_names:
        if name == "eo":
            n_obs, test_idx = max_obs_eo, eo_testset_list[h]
        elif name == "stwts":
            n_obs, test_idx = max_obs_stwts, stwts_testset_list[h]
        else:
            # Fail loudly: continuing would reuse a stale (or undefined) idx.
            raise ValueError("Unknown dataset name: {!r}".format(name))
        # One-sided set difference: a symmetric difference would put row 0
        # back into the pool whenever it was drawn as a test row (the range
        # starts at 1), leaking a test observation into the candidates.
        # NOTE(review): the pool starts at 1, so row 0 is never eligible even
        # though test indices are sampled from 0 -- confirm this is intended.
        # idx does not depend on the embedding, so compute it once per dataset.
        idx = list(set(range(1, n_obs)) - set(test_idx))
        for embed in embed_types:
            ks_matrix = np.load("data/output/" + name + '_ks_' + embed + '.npy')
            # Keep only the rows, then only the columns, of non-test observations.
            ks_matrix = ks_matrix[idx, :]
            ks_matrix = ks_matrix[:, idx]
            indices_list = list(si.farthestPointSampler(ks_matrix, n_select))
            out_path = "data/output/" + 'indices_' + name + '_' + embed + '_ks_iter' + str(h+1) + '.txt'
            with open(out_path, 'w') as filehandle:
                # A single string is written with write(); the file content
                # is the printed list, with no trailing newline.
                filehandle.write("%s" % indices_list)
            print('Completed.')


# ## Greedy Farthest points based on Cosine Distance

# In[14]:


# Greedy farthest-point selection on the precomputed '_cos_' matrices: for
# every bootstrap iteration, dataset and embedding, restrict the matrix to the
# non-test rows/columns, ask si.farthestPointSampler for max(counts) points
# and write the resulting list to a text file.
n_select = max(counts)
for h in range(len(eo_testset_list)):
    for name in dataset_names:
        if name == "eo":
            n_obs, test_idx = max_obs_eo, eo_testset_list[h]
        elif name == "stwts":
            n_obs, test_idx = max_obs_stwts, stwts_testset_list[h]
        else:
            # Fail loudly: continuing would reuse a stale (or undefined) idx.
            raise ValueError("Unknown dataset name: {!r}".format(name))
        # One-sided set difference: a symmetric difference would put row 0
        # back into the pool whenever it was drawn as a test row (the range
        # starts at 1), leaking a test observation into the candidates.
        # NOTE(review): the pool starts at 1, so row 0 is never eligible even
        # though test indices are sampled from 0 -- confirm this is intended.
        # idx does not depend on the embedding, so compute it once per dataset.
        idx = list(set(range(1, n_obs)) - set(test_idx))
        for embed in embed_types:
            cos_matrix = np.load("data/output/" + name + '_cos_' + embed + '.npy')
            # Keep only the rows, then only the columns, of non-test observations.
            cos_matrix = cos_matrix[idx, :]
            cos_matrix = cos_matrix[:, idx]
            indices_list = list(si.farthestPointSampler(cos_matrix, n_select))
            out_path = "data/output/" + 'indices_' + name + '_' + embed + '_cos_iter' + str(h+1) + '.txt'
            with open(out_path, 'w') as filehandle:
                # A single string is written with write(); the file content
                # is the printed list, with no trailing newline.
                filehandle.write("%s" % indices_list)
            print('Completed.')


# ## D-Optimality (Taddy)

# In[15]:


# D-optimality based selection: for every bootstrap iteration, dataset and
# embedding, run si.dopt on the non-test rows for each size in `counts` and
# write the results to a text file, one result per line.
for h in range(len(eo_testset_list)):
    for name in dataset_names:
        if name == "eo":
            n_obs, test_idx = max_obs_eo, eo_testset_list[h]
        elif name == "stwts":
            n_obs, test_idx = max_obs_stwts, stwts_testset_list[h]
        else:
            # Fail loudly: continuing would reuse a stale (or undefined) idx.
            raise ValueError("Unknown dataset name: {!r}".format(name))
        # One-sided set difference: a symmetric difference would put row 0
        # back into the pool whenever it was drawn as a test row (the range
        # starts at 1), leaking a test observation into the candidates.
        # NOTE(review): the pool starts at 1, so row 0 is never eligible even
        # though test indices are sampled from 0 -- confirm this is intended.
        # idx does not depend on the embedding, so compute it once per dataset.
        idx = list(set(range(1, n_obs)) - set(test_idx))
        for embed in embed_types:
            data = pd.read_csv("data/output/" + name + '_' + embed + '_full.csv', index_col=0)
            # Keep the non-test rows (positional) and hand them to torch as float32.
            data = si.torch.from_numpy(data.iloc[idx].to_numpy()).float()
            indices_list = [si.dopt(data, c) for c in counts]
            out_path = "data/output/" + 'indices_' + name + '_' + embed + '_dopt_iter' + str(h+1) + '.txt'
            with open(out_path, 'w') as filehandle:
                filehandle.writelines("%s\n" % idl for idl in indices_list)
            print('Completed.')

